#!/usr/bin/env python3
"""Script for generating a TileConnectionArray test case Vector."""
from collections import defaultdict
from itertools import product
import re

# Rust test file whose generated-tests region is rewritten by this script.
SOURCE_FILE_PATH = "src/tiling/tiles/tests/tile_connection_array.rs"


def _tile(name, north, east, south, west):
    """Build one tile entry for a case.

    `name` is the value substituted into `TILE_CONNECTION_TILES[...]` in the
    generated Rust code; the connectors are stored in
    (north, east, south, west) order.
    """
    return {"name": name, "connectors": (north, east, south, west)}


def _case(name, tiles, horizontal, vertical):
    """Build one CASES entry.

    `horizontal` bounds the east/west indices and `vertical` bounds the
    north/south indices of the generated connection array.
    """
    return {
        "name": name,
        "tiles": tiles,
        "horizontal connector count": horizontal,
        "vertical connector count": vertical,
    }


# Each entry becomes one `#[case::...]` block in the generated Rust test.
CASES = [
    _case("zero tiles", [], 1, 1),
    _case(
        "one tile, all connectors match their counterpart",
        [_tile(0, 1, 1, 1, 1)],
        2,
        2,
    ),
    _case(
        "one tile, no connectors match their counterpart",
        [_tile(1, 1, 2, 2, 1)],
        3,
        3,
    ),
    _case(
        "one tile, one connector matches its counterpart",
        [_tile(2, 1, 1, 2, 1)],
        2,
        3,
    ),
    _case(
        "two tiles",
        [_tile(1, 1, 2, 2, 1), _tile(3, 2, 3, 1, 4)],
        5,
        3,
    ),
    _case(
        "three tiles",
        [_tile(1, 1, 2, 2, 1), _tile(3, 2, 3, 1, 4), _tile(4, 3, 3, 2, 4)],
        5,
        4,
    ),
]


def _collect_tile_connections(tiles):
    """Group tile names by the connection keys they are listed under.

    Each tile is filed under every (north, east, south, west) key obtained by
    independently keeping each of its connectors or replacing it with 0.

    Returns a dict mapping key tuples to lists of tile names, in the order
    the tiles appear in `tiles`.
    """
    tile_connections = defaultdict(list)
    for tile in tiles:
        north, east, south, west = tile["connectors"]
        for key in product((0, north), (0, east), (0, south), (0, west)):
            tile_connections[key].append(tile["name"])
    return tile_connections


def _render_connection(key, tile_names, indices_by_name):
    """Render one `TileConnection { ... }` entry, preceded by its index comment.

    `key` is the (north, east, south, west) index, `tile_names` the names of
    the tiles filed under it, and `indices_by_name` maps a tile name to the
    positions of that tile within the case's tile list.
    """
    north, east, south, west = key
    tile_indices = []
    for name in tile_names:
        tile_indices.extend(indices_by_name[name])
    weights = [f"TILE_CONNECTION_TILES[{name}].weight" for name in tile_names]
    # Several weights go one per line; zero or one stays inline.
    if len(tile_names) > 1:
        weight_list = (
            "\n                "
            + ",\n                ".join(weights)
            + "\n            "
        )
    else:
        weight_list = ", ".join(weights)
    # The Python list repr ("[0, 1]") doubles as the body of the Rust `vec!`.
    return (
        f"        // tile_connection_array[{north}, {east}, {south}, {west}]\n"
        "        TileConnection {\n"
        f"            tiles: vec!{tile_indices},\n"
        f"            weights: vec![{weight_list}],\n"
        "        },\n"
    )


def _render_case(case):
    """Render the complete `#[case::...]` block (with region markers) for one case."""
    tiles = case["tiles"]
    horizontal = case["horizontal connector count"]
    vertical = case["vertical connector count"]

    tile_connections = _collect_tile_connections(tiles)
    # Built once per case so each connection does not rescan the tile list.
    indices_by_name = {}
    for tile_index, tile in enumerate(tiles):
        indices_by_name.setdefault(tile["name"], []).append(tile_index)

    case_name_normalized = case["name"].replace(" ", "_").replace(",", "")
    clones = [f"TILE_CONNECTION_TILES[{tile['name']}].clone()" for tile in tiles]
    # More than two tiles are listed one per line with a trailing comma.
    if len(tiles) > 2:
        tile_list = "\n        " + ",\n        ".join(clones) + ",\n    "
    else:
        tile_list = ", ".join(clones)

    parts = [
        f"// region {case['name']}\n#[case::{case_name_normalized}(\n"
        f"    vec![{tile_list}],\n"
        f"    {horizontal},\n"
        f"    {vertical},\n"
        "    vec![\n"
    ]
    # North/south indices are bounded by the vertical count, east/west by the
    # horizontal count; product() iterates with west varying fastest.
    for key in product(
        range(vertical), range(horizontal), range(vertical), range(horizontal)
    ):
        # .get() avoids inserting an empty list into the defaultdict on a miss.
        parts.append(
            _render_connection(key, tile_connections.get(key, []), indices_by_name)
        )
    parts.append("    ]\n)]\n// endregion\n")
    return "".join(parts)


# noinspection PyTypeChecker
def main():
    """Take the list of CASES and generate rstest cases for TileConnectionArray.

    The generated text replaces everything between the
    `// generated tests start` and `// generated tests end` marker lines of
    SOURCE_FILE_PATH.

    Raises:
        RuntimeError: if the marker lines are not found; the file is then
            left untouched.
    """
    output = "".join(_render_case(case) for case in CASES)

    # Rust source files are UTF-8; do not depend on the platform default.
    with open(SOURCE_FILE_PATH, encoding="utf-8") as source_file:
        file_content = source_file.read()

    # A replacement function inserts `output` verbatim, whereas a replacement
    # string would have its backslashes interpreted as escapes/group refs.
    file_content, replacement_count = re.subn(
        r"(?s)(// generated tests start\n).*?\n(// generated tests end)",
        lambda match: f"{match.group(1)}{output}{match.group(2)}",
        file_content,
    )
    if not replacement_count:
        raise RuntimeError(
            f"generated tests markers not found in {SOURCE_FILE_PATH}"
        )

    with open(SOURCE_FILE_PATH, "w", encoding="utf-8") as source_file:
        source_file.write(file_content)


# Run the generator only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
